#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 29 10:04:23 2025

@author: liangmeijing

ml_tabular.py — Train a simple neural network on genotype and phenotype data.
Uses Keras models for both classification and regression.

Auto-detects regression vs. classification from the values of the target column.
"""
import sys
print("[DEBUG] Running script from:", sys.argv[0])

import argparse
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import r2_score, accuracy_score
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense, Dropout
import os
from pathlib import Path
# from sklearn.preprocessing import StandardScaler

print("[1/8] Starting ml_tabular.py ...")

# ---------- Argument parsing ----------
# Options are described in small tables and registered in loops. The
# registration order matches the hand-written sequence, so the generated
# --help output and the attributes on `args` are the same.
parser = argparse.ArgumentParser(
    description="Train a neural network for genotype–phenotype prediction."
)

# Mandatory string options: (flag, help text).
_REQUIRED_OPTIONS = (
    ("--genotype", "Path to genotype file"),
    ("--phenotype", "Path to phenotype file"),
    ("--id-col", "Column name for sample IDs"),
    ("--target", "Trait column name to predict"),
)
for _flag, _help in _REQUIRED_OPTIONS:
    parser.add_argument(_flag, required=True, help=_help)

# Optional task selector, restricted to the two supported modes.
parser.add_argument("--task", choices=["classify", "regress"], required=False)

# Training hyper-parameters: (flag, type, default, help text).
_HYPERPARAMETER_OPTIONS = (
    ("--epochs", int, 50, "Number of training epochs (default: 50)"),
    ("--batch-size", int, 32, "Batch size during training (default: 32)"),
    ("--dropout", float, 0.3, "Dropout rate (default: 0.3)"),
)
for _flag, _kind, _default, _help in _HYPERPARAMETER_OPTIONS:
    parser.add_argument(_flag, type=_kind, default=_default, help=_help)

args = parser.parse_args()


print("[2/8] Parsed command-line arguments")

# ---------- Load data ----------
# sep=None together with the python engine asks pandas to sniff the delimiter,
# so comma- and tab-separated inputs are both accepted.
print(f"[3/8] Loading genotype data: {args.genotype}")
geno = pd.read_csv(args.genotype, sep=None, engine="python")
print(f"      Genotype shape: {geno.shape}")

print(f"[4/8] Loading phenotype data: {args.phenotype}")
pheno = pd.read_csv(args.phenotype, sep=None, engine="python")
print(f"      Phenotype shape: {pheno.shape}")

# ---------- Merge data ----------
print(f"[5/8] Merging genotype & phenotype on ID column: {args.id_col}")
# Header cells often carry stray whitespace; normalise before matching names.
geno.columns = geno.columns.str.strip()
pheno.columns = pheno.columns.str.strip()

# Fail with a readable message instead of the KeyError pd.merge would raise
# when the ID column is absent from either table.
for _label, _frame in (("genotype", geno), ("phenotype", pheno)):
    if args.id_col not in _frame.columns:
        print(f"❌ ERROR: ID column '{args.id_col}' not found in {_label} data.")
        print(f"Available columns include: {list(_frame.columns)[:10]} ...")
        raise SystemExit(1)

# Default inner join: only samples present in both files are kept.
data = pd.merge(geno, pheno, on=args.id_col)
data.columns = data.columns.str.strip()

print(f"      Data shape after merge: {data.shape}", flush=True)
print(f"      Columns after merge: {list(data.columns)[:10]} ...", flush=True)

# An empty join means no ID matched (e.g. differing ID formats); nothing can
# be trained on it, so stop here rather than failing later in the split.
if len(data) == 0:
    print(f"❌ ERROR: No samples share a '{args.id_col}' value between the genotype and phenotype files.")
    raise SystemExit(1)

# ---------- Verify target column ----------
# Stop early, after listing every merged column, when the requested trait
# column does not exist in the merged table.
if not (args.target in data.columns):
    print(f"❌ ERROR: Target column '{args.target}' not found in merged data.")
    print("Available columns include:")
    column_lines = [f"  - {name}" for name in data.columns]
    if column_lines:
        print("\n".join(column_lines))
    sys.exit(1)

# ---------- Split X / y ----------
print("[6/8] Splitting into train/test sets")
# Drop rows where the target is missing
data = data.dropna(subset=[args.target])

# The sample ID identifies a row; it is not a predictor. Leaving it in the
# feature matrix would feed numeric IDs to the network as an unscaled feature.
non_feature_cols = [args.target]
if args.id_col in data.columns and args.id_col != args.target:
    non_feature_cols.append(args.id_col)

# NOTE(review): every other merged column, including any additional phenotype
# columns, is still used as a feature — confirm that is intended.
X = data.drop(columns=non_feature_cols)
y = data[args.target]

print(f"      Dropped rows with missing target. Remaining samples: {len(data)}", flush=True)


# ---------- Determine task type ----------
# Numeric view of the target; values that cannot be parsed become NaN.
numeric_y = pd.to_numeric(y, errors="coerce")

# Share of target values that parsed as numbers, and number of distinct values.
non_null_ratio = numeric_y.notnull().mean()
unique_labels = y.unique()
unique_count = len(unique_labels)

if args.task is not None:
    # An explicit --task always wins over the heuristic.
    task = args.task
    task_origin = "User-specified"
    if task == "regress" and not non_null_ratio > 0:
        print(f"❌ ERROR: --task regress requested but target column '{args.target}' has no numeric values.")
        raise SystemExit(1)
else:
    # Heuristic: mostly-numeric target with more than 10 distinct values
    # → regression, anything else → classification.
    if non_null_ratio > 0.95 and unique_count > 10:
        task = "regress"
    else:
        task = "classify"
    task_origin = "Auto-detected"

print(f"      {task_origin} task type: {task.upper()} ({unique_count} unique values)", flush=True)
if task == "classify":
    # Quoting each label makes stray whitespace in class names visible.
    print("      Unique class labels detected:")
    for val in unique_labels:
        print(f"         • '{val}'")


# ---------- Prepare data ----------
features = pd.DataFrame(X)

if task == "classify":
    # Encode class labels as integer category codes (0 .. n_classes - 1) and
    # hand a plain NumPy array to the splitter / Keras.
    y = y.astype("category").cat.codes.values.astype("int32")
else:
    # Rows whose target could not be parsed as a number are removed from both
    # X and y. Filling them with 0.0 would invent target values and bias the fit.
    numeric_mask = numeric_y.notna().values
    n_dropped = int((~numeric_mask).sum())
    if n_dropped:
        print(f"      Dropped {n_dropped} rows with non-numeric target values", flush=True)
        features = features.loc[numeric_mask]
    y = numeric_y[numeric_mask].values.astype("float32")

# Non-numeric feature cells are coerced to NaN and then zero-filled.
# NOTE: features are passed to the network unscaled.
X = features.apply(pd.to_numeric, errors="coerce").fillna(0).values.astype("float32")

# Fixed seed keeps the 80/20 split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print(f"      Train size: {X_train.shape}, Test size: {X_test.shape}")

# ---------- Train model ----------
print("[7/8] Training model...", flush=True)

n_features = X_train.shape[1]

if task == "classify":
    # Labels are 0-based integer category codes produced during data
    # preparation. Count classes over BOTH splits: a class that only landed in
    # the test split would otherwise leave the output layer one unit short.
    all_labels = np.concatenate([np.asarray(y_train), np.asarray(y_test)])
    num_classes = int(all_labels.max()) + 1
    print(f"      Detected {num_classes} unique classes")

    if num_classes > 2:
        # Multi-class: one softmax unit per class, integer labels.
        output_layer = Dense(num_classes, activation="softmax")
        loss_name = "sparse_categorical_crossentropy"
    else:
        # Binary (or single-class): binary_crossentropy expects ONE sigmoid
        # unit per sample. A 2-unit sigmoid head does not match the (n,)
        # label vector.
        output_layer = Dense(1, activation="sigmoid")
        loss_name = "binary_crossentropy"

    model = Sequential([
        Dense(128, activation="relu", input_shape=(n_features,)),
        Dropout(args.dropout),
        Dense(64, activation="relu"),
        output_layer,
    ])

    model.compile(optimizer="adam", loss=loss_name, metrics=["accuracy"])

    # The held-out split doubles as validation data; progress output is muted.
    model.fit(X_train, y_train, validation_data=(X_test, y_test),
              epochs=args.epochs, batch_size=args.batch_size, verbose=0)
    loss, acc = model.evaluate(X_test, y_test, verbose=0)
    print(f"[Result] Classification accuracy: {acc:.4f}", flush=True)

else:
    # Regression: same hidden stack, single linear output trained with MSE.
    model = Sequential([
        Dense(128, activation="relu", input_shape=(n_features,)),
        Dropout(args.dropout),
        Dense(64, activation="relu"),
        Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")

    model.fit(X_train, y_train, validation_data=(X_test, y_test),
              epochs=args.epochs, batch_size=args.batch_size, verbose=0)
    # verbose=0 keeps prediction quiet, consistent with fit().
    preds = model.predict(X_test, verbose=0)
    if preds.ndim > 1:
        # Flatten the (n, 1) prediction matrix to match y_test's shape.
        preds = preds.ravel()
    r2 = r2_score(y_test, preds)
    print(f"[Result] R² score: {r2:.4f}", flush=True)

# ---------- Save model ----------
# Save to a "models" folder next to this script (ml_tabular.py)
save_dir = Path(__file__).resolve().parent / "models"
save_dir.mkdir(parents=True, exist_ok=True)

# A trait name containing a path separator would point into a directory that
# does not exist and make the save fail after training has already finished.
# Replace separators so the file always lands directly inside save_dir.
safe_target = args.target.replace("/", "_").replace("\\", "_")
model_path = save_dir / f"model_{safe_target}_{task}.keras"

model.save(model_path)
print(f"[8/8] ✅ Model saved to {model_path}")
